# coding: utf-8
# pylint: disable= invalid-name
"""Training Library containing training routines."""
from __future__ import absolute_import

from . import rabit
from .core import EarlyStopException


def _get_callback_context(env):
    """return whether the current callback context is cv or train"""
    if env.model is not None and env.cvfolds is None:
        context = 'train'
    elif env.model is None and env.cvfolds is not None:
        context = 'cv'
    return context


def _fmt_metric(value, show_stdv=True):
    """format metric string"""
    if len(value) == 2:
        return '%s:%g' % (value[0], value[1])
    elif len(value) == 3:
        if show_stdv:
            return '%s:%g+%g' % (value[0], value[1], value[2])
        else:
            return '%s:%g' % (value[0], value[1])
    else:
        raise ValueError("wrong metric value")


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation result.

    The evaluation results are printed every **period** iterations
    and on the first and the last iterations.

    Parameters
    ----------
    period : int
        The period to log the evaluation results. A falsy value
        (``0``, ``False`` or ``None``) disables printing.

    show_stdv : bool, optional
         Whether to show stdv if provided

    Returns
    -------
    callback : function
        A callback that prints the evaluation every period iterations.
    """
    def callback(env):
        """internal function"""
        # Only rank 0 prints, and only when there is something to print.
        if env.rank != 0 or not env.evaluation_result_list or not period:
            return
        i = env.iteration
        # ``i == env.begin_iteration`` is the first round of this run; the
        # former ``i + 1 == env.begin_iteration`` could never hold for
        # ``i >= begin_iteration``, so the documented first-iteration print
        # was skipped whenever the run did not start on a multiple of period.
        if i % period == 0 or i == env.begin_iteration or i + 1 == env.end_iteration:
            msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
            rabit.tracker_print('[%d]\t%s\n' % (i, msg))
    return callback


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into **eval_result**.

    Each evaluation label is split at its first ``'-'`` into a data name and
    a metric name; scores are appended to ``eval_result[data][metric]``.

    Parameters
    ----------
    eval_result : dict
       A dictionary to store the evaluation results. It is emptied
       when the callback is created.

    Returns
    -------
    callback : function
        The requested callback function.

    Raises
    ------
    TypeError
        If **eval_result** is not a dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result has to be a dictionary')
    eval_result.clear()

    def split_label(label):
        """internal function: 'data-metric' -> (data, metric)"""
        dash = label.index('-')
        return label[:dash], label[dash + 1:]

    def init(env):
        """internal function"""
        # Build the nested dict skeleton from the first batch of results.
        for label, _ in env.evaluation_result_list:
            data_name, metric_name = split_label(label)
            eval_result.setdefault(data_name, {}).setdefault(metric_name, [])

    def callback(env):
        """internal function"""
        if not eval_result:
            init(env)
        for label, score in env.evaluation_result_list:
            data_name, metric_name = split_label(label)
            eval_result[data_name][metric_name].append(score)
    return callback


def reset_learning_rate(learning_rates):
    """Reset learning rate after iteration 1

    NOTE: the initial learning rate will still take in-effect on first iteration.

    Parameters
    ----------
    learning_rates: list or function
        List of learning rate for each boosting round
        or a customized function that calculates eta in terms of
        current number of round and the total number of boosting round (e.g.
        yields learning rate decay)

        * list ``l``: ``eta = l[boosting_round]``
        * function ``f``: ``eta = f(boosting_round, num_boost_round)``

    Returns
    -------
    callback : function
        The requested callback function. It carries the attribute
        ``before_iteration = True``.
    """
    def rate_for_round(round_idx, total_rounds):
        """helper: look up or compute the rate for one round"""
        if not isinstance(learning_rates, list):
            return learning_rates(round_idx, total_rounds)
        if len(learning_rates) != total_rounds:
            raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
        return learning_rates[round_idx]

    def callback(env):
        """internal function"""
        context = _get_callback_context(env)

        # Gather every booster that needs the new rate: the single model
        # when training, one booster per fold when cross-validating.
        if context == 'train':
            boosters = [env.model]
        elif context == 'cv':
            boosters = [fold.bst for fold in env.cvfolds]
        else:
            boosters = []

        round_idx, total_rounds = env.iteration, env.end_iteration
        for booster in boosters:
            booster.set_param('learning_rate', rate_for_round(round_idx, total_rounds))

    callback.before_iteration = True
    return callback


def early_stop(stopping_rounds, maximize=False, verbose=True):
    """Create a callback that activates early stopping.

    Validation error needs to decrease at least
    every **stopping_rounds** round(s) to continue training.
    Requires at least one item in **evals**.
    If there's more than one, will use the last.
    Returns the model from the last iteration (not the best one).
    If early stopping occurs, the model will have three additional fields:
    ``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``.
    (Use ``bst.best_ntree_limit`` to get the correct value if ``num_parallel_tree``
    and/or ``num_class`` appears in the parameters)

    Parameters
    ----------
    stopping_rounds : int
       The number of rounds without improvement after which training stops.

    maximize : bool
        Whether to maximize evaluation metric. Forced to ``True`` for the
        metrics ``auc``, ``aucpr``, ``map``, ``ndcg`` and their ``@n`` variants.

    verbose : optional, bool
        Whether to print message about early stopping information.

    Returns
    -------
    callback : function
        The requested callback function. It raises ``EarlyStopException``
        once the last metric has not improved for **stopping_rounds** rounds,
        and ``ValueError`` on its first call if there is no evaluation result.
    """
    state = {}

    def init(env):
        """internal function"""
        bst = env.model

        if len(env.evaluation_result_list) == 0:
            raise ValueError('For early stopping you need at least one set in evals.')
        if len(env.evaluation_result_list) > 1 and verbose:
            msg = ("Multiple eval metrics have been passed: "
                   "'{0}' will be used for early stopping.\n\n")
            rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
        maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg')
        # NOTE: every prefix needs its own comma; adjacent string literals
        # are silently concatenated ('aucpr@' 'map@' == 'aucpr@map@').
        maximize_at_n_metrics = ('auc@', 'aucpr@', 'map@', 'ndcg@')
        maximize_score = maximize
        metric_label = env.evaluation_result_list[-1][0]
        # Strip the leading data name: 'eval-map@3' -> 'map@3'.
        metric = metric_label.split('-', 1)[-1]

        if metric.startswith(maximize_at_n_metrics):
            maximize_score = True

        if metric.split(":")[0] in maximize_metrics:
            maximize_score = True

        if verbose and env.rank == 0:
            msg = "Will train until {} hasn't improved in {} rounds.\n"
            rabit.tracker_print(msg.format(metric_label, stopping_rounds))

        state['maximize_score'] = maximize_score
        state['best_iteration'] = 0
        if maximize_score:
            state['best_score'] = float('-inf')
        else:
            state['best_score'] = float('inf')

        if bst is not None:
            if bst.attr('best_score') is not None:
                # Resume from the values stored on the model.
                state['best_score'] = float(bst.attr('best_score'))
                state['best_iteration'] = int(bst.attr('best_iteration'))
                state['best_msg'] = bst.attr('best_msg')
            else:
                bst.set_attr(best_iteration=str(state['best_iteration']))
                bst.set_attr(best_score=str(state['best_score']))
        else:
            assert env.cvfolds is not None

    def callback(env):
        """internal function"""
        # Initialise first so that an empty result list is reported by the
        # ValueError in init() instead of an IndexError on the lookup below.
        if not state:
            init(env)
        score = env.evaluation_result_list[-1][1]
        best_score = state['best_score']
        best_iteration = state['best_iteration']
        maximize_score = state['maximize_score']
        if (maximize_score and score > best_score) or \
                (not maximize_score and score < best_score):
            msg = '[%d]\t%s' % (
                env.iteration,
                '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
            state['best_msg'] = msg
            state['best_score'] = score
            state['best_iteration'] = env.iteration
            # save the property to attributes, so they will occur in checkpoint.
            if env.model is not None:
                env.model.set_attr(best_score=str(state['best_score']),
                                   best_iteration=str(state['best_iteration']),
                                   best_msg=state['best_msg'])
        elif env.iteration - best_iteration >= stopping_rounds:
            # 'best_msg' is absent when no round ever improved on the initial
            # score (e.g. the metric is NaN); fall back to the iteration tag
            # instead of failing with a KeyError.
            best_msg = state.get('best_msg')
            if best_msg is None:
                best_msg = '[%d]' % best_iteration
            if verbose and env.rank == 0:
                msg = "Stopping. Best iteration:\n{}\n\n"
                rabit.tracker_print(msg.format(best_msg))
            raise EarlyStopException(best_iteration)
    return callback
